Recreating branching model animations

Artem Kirsanov, August 2023

import numpy as np
from copy import deepcopy
import itertools
import matplotlib
import matplotlib.pyplot as plt
from numpy import radians as rad
from matplotlib.animation import FuncAnimation
from scipy.ndimage import convolve,convolve1d
import cmasher
import seaborn as sns

Matplotlib

# Network geometry. These module-level globals are read by network_init()
# each time it is called, so reassigning them changes the shape of new networks.
NUM_LAYERS = 20
NEURONS_PER_LAYER = 10


def network_init():
    '''Create a silent network.

    Returns an all-False boolean array of shape (NUM_LAYERS, NEURONS_PER_LAYER):
    one row per layer, one column per neuron.
    '''
    shape = (NUM_LAYERS, NEURONS_PER_LAYER)
    return np.full(shape, False, dtype=bool)


def network_advance(old_network, sigma, spont_prob):
    '''Advance the branching network by one time step.

    Parameters
    ----------
    old_network : array-like, shape (n_layers, n_neurons)
        Current activity (bool or 0/1 values). It is copied, never modified.
    sigma : float
        Branching parameter. Each neuron of layer ``l`` becomes active with
        probability ``sigma * (number of active neurons in layer l-1) / n_neurons``.
    spont_prob : float
        Per-neuron probability of a spontaneous activation during this step.

    Returns
    -------
    ndarray
        New activity array with the same shape and dtype as the copy of
        ``old_network``. Layer 0 is always all zeros on return.
    '''
    # Sizes come from the array itself rather than the module-level globals,
    # which are reassigned later in the file.
    network = np.array(old_network, copy=True)
    n_layers, n_neurons = network.shape

    # Random spontaneous activity.
    # NOTE(review): this is applied *before* propagation. A spontaneous
    # activation in layer l < n_layers-1 only raises the firing probability of
    # layer l+1 before layer l is overwritten or cleared below. One in the last
    # layer is overwritten outright. So spontaneous events are never visible in
    # the layer where they occur. Confirm this is the intended model.
    spont = np.random.rand(*network.shape)
    network[spont < spont_prob] = 1

    # Walk from the last layer backwards so that each layer reads its upstream
    # neighbour's activity before that neighbour is overwritten.
    for layer_num in range(n_layers - 1, 0, -1):
        firing_prob = sigma * np.sum(network[layer_num - 1, :]) / n_neurons
        network[layer_num] = np.random.rand(n_neurons) < firing_prob
        # Activity has moved on. Only layer 0 keeps this zeroing: every other
        # layer is overwritten by the next iteration.
        network[layer_num - 1] = np.zeros(n_neurons)
    return network


def run_simulation(network, n_steps, sigma=1, spont_prob=0.01):
    '''Run simulation with stochastic activity for n_steps.

    Parameters
    ----------
    network : array-like, shape (n_layers, n_neurons)
        Initial state, stored unchanged as frame 0.
    n_steps : int
        Total number of frames to return, including the initial state.
    sigma : float
        Branching parameter passed to ``network_advance`` at every step.
    spont_prob : float
        Spontaneous-activation probability passed to ``network_advance``.

    Returns
    -------
    ndarray of float, shape (n_steps, n_layers, n_neurons)
        Frame ``t`` is ``network_advance`` applied to frame ``t - 1``. When
        ``n_steps == 0`` the result is empty along the first axis and
        ``network_advance`` is never called.
    '''
    network = np.asarray(network)
    # Shape taken from the initial state, so it always matches what
    # network_advance returns.
    network_states = np.zeros((n_steps, *network.shape))
    if n_steps == 0:
        return network_states

    network_states[0] = network
    for step in range(1, n_steps):
        network_states[step] = network_advance(network_states[step - 1], sigma, spont_prob)
    return network_states
network = network_init()
evolution = run_simulation(network, 50, sigma=1, spont_prob=0.01)

# Figure with a white background and no axes
fig, ax = plt.subplots(1, 1, figsize=(10, 5), dpi=200)
ax.axis(False)
for artist in (fig, ax):
    artist.set_facecolor("white")

# One cell per neuron. The transpose puts layers along x and neurons along y.
cmesh = ax.pcolormesh(
    evolution[0, :, :].T,
    edgecolors='white',
    vmin=0,
    vmax=1,
    linewidth=2,
    cmap=plt.cm.coolwarm,
)


def anim_function(frame_num):
    """Display network state number ``frame_num``; return the updated artists."""
    state = evolution[frame_num, :, :]
    cmesh.set_array(state.T)
    return (cmesh,)


# interval=30 ms per frame (the saved video runs at ~33.3 fps)
anim = FuncAnimation(fig, anim_function, frames=np.arange(evolution.shape[0]), interval=30)
anim.save("Network evolution raw fast.mp4")
[01/29/26 13:38:31] INFO     Animation.save using <class 'matplotlib.animation.FFMpegWriter'>     animation.py:1076
                    INFO     MovieWriter._run: running command: ffmpeg -f rawvideo -vcodec         animation.py:319
                             rawvideo -s 2000x1000 -pix_fmt rgba -framerate 33.333333333333336                     
                             -loglevel error -i pipe: -vcodec h264 -pix_fmt yuv420p -y 'Network                    
                             evolution raw fast.mp4'                                                               

# --- Running the model: a longer 500-step run, used for the smoothed animation
network = network_init()
evolution = run_simulation(network, n_steps=500, sigma=1, spont_prob=0.01)

# --- Smoothing activity
def smooth_activity(network_states, time_stretch=3, slope=-60):
    '''
        Smooth the activity in time for a more eye-pleasant animation.

        Input step ``i`` is placed at output frame ``i * time_stretch``; the
        frames in between start at zero. The result is then convolved along
        the time axis with a symmetric exponential kernel whose peak is 1.
        Each isolated activation therefore becomes a bump of height 1 that
        rises before and decays after its frame. Overlapping bumps add, so
        values can exceed 1.

        Parameters
        ----------
        network_states : ndarray, shape (n_steps, ...)
            Activity per time step, e.g. the output of ``run_simulation``.
            Any trailing shape is accepted.
        time_stretch : int, default 3
            Number of output frames per input step; must be >= 1.
        slope : float, default -60
            Exponent of the kernel ``exp(slope * t)``. More negative values
            give a narrower bump; -60 is the value previously hard-coded.

        Returns
        -------
        ndarray of float, shape (n_steps * time_stretch, ...)

        Raises
        ------
        ValueError
            If ``time_stretch`` is smaller than 1.
    '''
    def get_symmetric_kernel(slope=-20, npoints=100):
        '''Peak-normalised symmetric exponential kernel on ``npoints`` samples.

        ``npoints`` must be even. The samples with t > 0.5 are mirrored onto
        the samples with t <= 0.5, and the two halves are the same size only
        when ``npoints`` is even.
        '''
        t = np.linspace(0, 1, npoints)
        t_mask = t > 0.5
        kernel = np.zeros_like(t)
        kernel[t_mask] = np.exp(slope * t[t_mask])
        kernel[~t_mask] = kernel[t_mask][::-1]  # mirror the decaying half
        return kernel / kernel[t_mask][0]  # normalise so the peak equals 1

    if time_stretch < 1:
        raise ValueError(f"time_stretch must be >= 1, got {time_stretch}")

    kernel = get_symmetric_kernel(slope)
    n_steps, *cell_shape = network_states.shape
    smoothed_activity = np.zeros((n_steps * time_stretch, *cell_shape))
    smoothed_activity[::time_stretch] = network_states
    # mode="constant" pads with zeros outside the recorded time range
    return convolve1d(smoothed_activity, kernel, axis=0, mode="constant", origin=0)
# --- Animation
smoothed_evolution = smooth_activity(evolution)

# Figure with a white background and no axes
fig, ax = plt.subplots(1, 1, figsize=(10, 5), dpi=200)
ax.axis(False)
for artist in (fig, ax):
    artist.set_facecolor("white")

# Use only the 0.2-1 portion of seaborn's "mako" colormap
cmap = cmasher.get_sub_cmap(sns.color_palette("mako", as_cmap=True), 0.2, 1)
cmesh = ax.pcolormesh(
    smoothed_evolution[0, :, :].T,
    edgecolors='white',
    vmin=0,
    vmax=1,
    linewidth=2,
    cmap=cmap,
)


def anim_function(frame_num):
    """Display smoothed frame number ``frame_num``; return the updated artists."""
    frame = smoothed_evolution[frame_num, :, :]
    cmesh.set_array(frame.T)
    return (cmesh,)


# interval=30 ms per frame (the saved video runs at ~33.3 fps)
anim = FuncAnimation(fig, anim_function, frames=np.arange(smoothed_evolution.shape[0]), interval=30)
anim.save("Network evolution smoothed.mp4")
[01/29/26 13:38:53] INFO     Animation.save using <class 'matplotlib.animation.FFMpegWriter'>     animation.py:1076
                    INFO     MovieWriter._run: running command: ffmpeg -f rawvideo -vcodec         animation.py:319
                             rawvideo -s 2000x1000 -pix_fmt rgba -framerate 33.333333333333336                     
                             -loglevel error -i pipe: -vcodec h264 -pix_fmt yuv420p -y 'Network                    
                             evolution smoothed.mp4'                                                               

Manim

# Smaller 10 x 10 network for the Manim scene. network_init() reads these
# globals when called, so the simulation below uses the new sizes.
NUM_LAYERS = 10
NEURONS_PER_LAYER = 10
NUM_FRAMES = 2000  # simulation steps; smooth_activity stretches them 3x by default

# --- Simulation
network = network_init()
network_states = run_simulation(network, n_steps=NUM_FRAMES, sigma=1, spont_prob=0.01)
smoothed_states = smooth_activity(network_states)
def multilayered_graph(subset_sizes, edge_prob=0.35):
    ''' Generate a networkx multilayered graph with the specified layer sizes.

    Nodes are numbered consecutively, layer by layer: layer 0 gets
    ``0 .. subset_sizes[0]-1``, layer 1 the next ``subset_sizes[1]`` integers,
    and so on. Each node carries a ``layer`` attribute with its layer index.
    Edges only join consecutive layers. For each pair of consecutive layers,
    exactly ``int(n_possible_edges * edge_prob)`` edges are drawn uniformly at
    random without replacement (using ``np.random``). ``edge_prob`` is
    therefore an exact fraction, not an independent per-edge probability.

    Parameters
    ----------
    subset_sizes : iterable of int
        Number of nodes in each layer (list, tuple or any other iterable).
    edge_prob : float, default 0.35
        Fraction of the possible edges between consecutive layers to keep.

    Returns
    -------
    networkx.Graph
    '''
    # [0, *sizes] instead of [0] + sizes so that tuples/iterables also work
    boundaries = itertools.accumulate([0, *subset_sizes])
    layers = [range(start, end) for start, end in nx.utils.pairwise(boundaries)]

    G = nx.Graph()
    for i, layer in enumerate(layers):
        G.add_nodes_from(layer, layer=i)

    for layer1, layer2 in nx.utils.pairwise(layers):
        all_edges = list(itertools.product(layer1, layer2))
        n_keep = int(len(all_edges) * edge_prob)
        selected_edges = np.random.choice(len(all_edges), size=n_keep, replace=False)
        G.add_edges_from(all_edges[k] for k in selected_edges)
    return G
from manim import *
import networkx as nx
from scipy.interpolate import interp1d
import itertools
# --- Animation with Manim
class BranchingModelRearranging(Scene):
    '''Manim scene: the simulated network is first drawn with its nodes at
    shuffled positions. The nodes then slide into a layered arrangement while
    their colours keep following the smoothed simulation data.

    Reads the module-level ``NUM_LAYERS``, ``NEURONS_PER_LAYER``,
    ``smoothed_states`` and ``multilayered_graph``.
    '''

    def construct(self):
        # Set up coordinate systems: a square area for the shuffled layout and a
        # wide one for the layered layout, both spanning one unit per layer (x)
        # and one unit per neuron (y).
        shuffled_ax = Axes(x_range=(0, NUM_LAYERS), y_range=(0, NEURONS_PER_LAYER), x_length=7, y_length=7)
        layers_ax = Axes(x_range=(0, NUM_LAYERS), y_range=(0, NEURONS_PER_LAYER), x_length=13, y_length=7)
        n_nodes = NUM_LAYERS * NEURONS_PER_LAYER

        # --- Mapping: node k <-> grid cell (k // NEURONS_PER_LAYER, k % NEURONS_PER_LAYER).
        # This matches multilayered_graph's layer-by-layer node numbering and the
        # C-order reshape of smoothed_states below.
        mapping = np.array(list(itertools.product(range(NUM_LAYERS), range(NEURONS_PER_LAYER))), dtype=object)
        layout_layered = {k: layers_ax.c2p(*mapping[k]) for k in range(n_nodes)}
        np.random.shuffle(mapping)
        layout_shuffle = {k: shuffled_ax.c2p(*mapping[k]) for k in range(n_nodes)}

        # Construct a graph object
        G = multilayered_graph([NEURONS_PER_LAYER] * NUM_LAYERS)
        graph = Graph.from_networkx(G, layout=layout_shuffle, vertex_config={'radius': 0.2},
                                    edge_config={"stroke_width": 0.5, "stroke_color": GRAY})

        # Maps a (fractional) frame index to the activity of every node.
        # interp1d raises ValueError outside [0, n_frames - 1]; see the frame check below.
        n_frames = smoothed_states.shape[0]
        value_interp_function = interp1d(np.arange(n_frames),
                                         smoothed_states.reshape(n_frames, n_nodes), axis=0)

        cmap = cmasher.get_sub_cmap(sns.color_palette("mako", as_cmap=True), 0.2, 1)

        time_tracker = ValueTracker()  # current (fractional) frame index into smoothed_states

        def update_node_colors(graph):
            # Interpolate the whole activity vector once per update rather than
            # once per node (it is the same vector for every node).
            values = value_interp_function(time_tracker.get_value())
            for k in range(len(G.nodes)):
                graph[k].set_color(rgba_to_color(cmap(values[k])))

        graph.add_updater(update_node_colors)
        self.add(graph)

        # --- Animating
        # FPS = frames of smoothed_states advanced per second of video. It is
        # independent of the render frame rate.
        FPS = 30
        PLAY_TIME_BEFORE_REARRANGING = 20
        PLAY_TIME_AFTER_REARRANGING = 5
        REARRANGING_TIME = 2

        # Fail fast instead of hitting interp1d's out-of-range error mid-render.
        frames_needed = sum(int(t * FPS) for t in (PLAY_TIME_BEFORE_REARRANGING,
                                                   REARRANGING_TIME,
                                                   PLAY_TIME_AFTER_REARRANGING))
        if frames_needed > n_frames - 1:
            raise ValueError(f"Animation needs {frames_needed} simulation frames, "
                             f"but smoothed_states only has {n_frames}")

        def get_shuffle2layered_anims():
            return [graph[k].animate.move_to(layout_layered[k]) for k in range(len(G.nodes))]

        def animate_network(playing_time):
            self.play(time_tracker.animate.increment_value(int(playing_time * FPS)),
                      run_time=playing_time, rate_func=linear)

        animate_network(PLAY_TIME_BEFORE_REARRANGING)

        # Move nodes into layers while the simulation keeps playing
        self.play(*get_shuffle2layered_anims(),
                  time_tracker.animate.increment_value(int(REARRANGING_TIME * FPS)),
                  run_time=REARRANGING_TIME, rate_func=linear)

        animate_network(PLAY_TIME_AFTER_REARRANGING)
        self.wait()